import taichi as ti
from ..camera import Camera
from ..utils import *


class Pinhole(Camera):
    """Pinhole (aperture) camera.

    Holds the field of view and a lens radius, and derives from them the
    per-pixel step vectors of the screen. ``rayAt`` uses these to build a
    ray for a given screen coordinate.

    Attributes such as ``theta``, ``phi``, ``aspectRatio``, ``w``, ``h``,
    ``origin`` and ``SSNumber`` are read here but never assigned in this
    class; presumably they are provided by ``Camera`` -- confirm there.
    """

    def __init__(
        self, size: "tuple[int, int]", fov: float = 60, lenRadius: float = 0
    ) -> None:
        """Create the camera and compute its initial screen layout.

        Args:
            size: Forwarded unchanged to ``Camera.__init__``.
            fov: Field of view; multiplied by ``deg2rad`` in ``setScreen``,
                so it is expected in degrees.
            lenRadius: Scale applied to the random lens sample in ``rayAt``.
        """
        super().__init__(size)
        self.fov = fov

        # Everything read inside the Taichi func ``rayAt`` is kept in a
        # one-element field, so it can be rewritten from Python scope.
        self.lenRadius: ti.ScalarField = ti.field(dtype=float, shape=1)
        self.lenRadius[0] = lenRadius
        self.pixelW: ti.MatrixField = Vec3.field(shape=1)
        self.pixelH: ti.MatrixField = Vec3.field(shape=1)
        self.pixelWH: ti.MatrixField = Mat3x2.field(shape=1)
        self.offset: ti.MatrixField = Vec3.field(shape=1)

        self.setScreen()

    def setScreen(self):
        """Recompute the screen axes, pixel step vectors and corner offset.

        Requires ``theta``, ``phi`` and ``fov`` (plus ``aspectRatio``, ``w``
        and ``h``) to be set beforehand.

        Returns:
            ``self``, so calls can be chained.
        """
        sin_theta = ti.sin(self.theta)
        cos_phi, sin_phi = ti.cos(self.phi), ti.sin(self.phi)

        # Viewing direction from the two angles (theta, phi).
        forward = Vec3(cos_phi * sin_theta, ti.cos(self.theta), sin_phi * sin_theta)
        # Horizontal axis: phi turned back by a quarter turn, zero y component.
        side_angle = self.phi - PI / 2
        right = Vec3(ti.cos(side_angle), 0, ti.sin(side_angle))
        up = forward.cross(right)

        # Half extents of the screen placed one unit along ``forward``.
        half_h = ti.tan(self.fov * deg2rad / 2)
        half_w = self.aspectRatio * half_h

        self.x = Vec4(right, 0)
        self.y = Vec4(up, 0)
        self.z = Vec4(forward, 0)

        # World-space step for moving one pixel vertically / horizontally.
        self.pixelH[0] = 2 * half_h * up / self.h
        self.pixelW[0] = 2 * half_w * right / self.w

        # Use the values as stored in the fields for the derived quantities.
        step_w, step_h = self.pixelW[0], self.pixelH[0]

        # Rows (pixelW, pixelH) transposed into columns: a 3x2 matrix that
        # maps a 2D pixel coordinate to a 3D displacement on the screen.
        self.pixelWH[0] = Mat2x3([[*step_w], [*step_h]]).transpose()

        # Vector from the camera origin to the screen corner at uv = (0, 0):
        # the screen centre minus half of the full width and height.
        self.offset[0] = forward - (step_w * self.w + step_h * self.h) / 2
        return self

    @ti.func
    def rayAt(self, uv: Vec2, t: int = 0) -> Ray4:
        """Build the ray through screen coordinate ``uv``.

        Args:
            uv: Screen coordinate, in units of the pixel step vectors.
            t: Passed to ``superSampling`` together with ``SSNumber``;
                presumably the index of the current sub-sample -- confirm
                against ``superSampling``.

        Returns:
            A ``Ray4`` whose origin is the camera origin moved by the lens
            shift and whose direction is the normalized vector towards the
            sampled screen point, with its fourth component set to -1.
        """
        # Random lens sample scaled by the lens radius. randUnitVec2 is
        # defined elsewhere; presumably a random 2D unit vector.
        lens = self.lenRadius[0] * randUnitVec2()
        # The lens sample is applied along the pixel step vectors, so the
        # radius is measured in pixel steps.
        shift = self.pixelW[0] * lens.x + self.pixelH[0] * lens.y
        jitter = superSampling(self.SSNumber, t)

        ray_origin = Vec4(self.origin[0].xyz + shift, self.origin[0].w)
        # Subtracting the same shift keeps the ray aimed at the same screen
        # point, whatever lens sample was drawn.
        target = self.offset[0] + self.pixelWH[0] @ (uv + jitter) - shift
        ray_direction = Vec4(target.normalized(), -1)
        return Ray4(ray_origin, ray_direction)

    def makeSettingPlane(self, gui: ti.ui.Gui) -> bool:
        """Draw the camera sliders and apply any edited value.

        Args:
            gui: GUI object on which the two float sliders are drawn.

        Returns:
            True if either slider returned a value different from the
            current one (the screen is then recomputed), otherwise False.
        """
        changed = False

        fov = gui.slider_float("fov", self.fov, 1, 180)
        if fov != self.fov:
            self.fov = fov
            changed = True

        radius = self.lenRadius[0]
        picked_radius = gui.slider_float("len radius", radius, 0, 2)
        if picked_radius != radius:
            self.lenRadius[0] = picked_radius
            changed = True

        if changed:
            self.setScreen()
        return changed
